import argparse
import asyncio
import json
import os
import re
from collections import Counter
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Any, Union
import aiohttp  

class SVAmpSolver:
    """SVAMP solver using Chain-of-Thought Self-Consistency (CoT-SC).

    Each problem is sent ``num_paths`` times with the same step-by-step prompt.
    A numeric answer is extracted from every completion, and the most frequent
    answer is chosen by majority vote. The instance also tracks running accuracy
    (``stats``) and the input/output token counts of the current problem
    (``token_counts``).
    """

    # Argument of a LaTeX boxed answer, e.g. "\boxed{42}" (nested braces unsupported).
    _BOXED_RE = re.compile(r'\\boxed\{([^{}]+)\}')
    # Remainder of a "Final Answer: ..." line, matched case-insensitively.
    _FINAL_ANSWER_RE = re.compile(r'Final\s+Answer\s*:\s*([^\n]+)', re.IGNORECASE)
    # A signed decimal number. Allows thousands separators ("1,234.5") and a
    # bare fractional part (".5").
    _NUMBER_RE = re.compile(r'-?(?:(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?|\.\d+)')

    def __init__(self):
        self.model = "your model"  # placeholder: model name sent in each request
        self.base_url = "your base_url"  # placeholder: prefix of the /chat/completions URL
        # [input_tokens, output_tokens] accumulated for the current problem.
        self.token_counts = [0, 0]
        self.stats = {
            "total_problems": 0,
            "correct_answers": 0,
            "incorrect_answers": 0,
            "accuracy": 0.0
        }
        # Number of independent reasoning paths sampled per problem.
        self.num_paths = 5

    async def generate(
        self, prompt: str, session: "Optional[aiohttp.ClientSession]" = None
    ) -> str:
        """Send *prompt* as one user message and return the completion text.

        Args:
            prompt: User message content.
            session: Optional shared HTTP session. When omitted, a temporary
                session is opened for this single request.

        Returns:
            The content of the first choice ("" if the server returned null).

        Raises:
            aiohttp.ClientResponseError: on an HTTP error status.
            Exception: any other transport or response-parsing error. Request
                errors are printed before being re-raised.
        """
        if session is None:
            async with aiohttp.ClientSession() as own_session:
                return await self.generate(prompt, own_session)

        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            # Non-zero temperature so the sampled reasoning paths can differ.
            "temperature": 0.3,
            "max_tokens": 8000,
            "top_p": 0.8
        }
        try:
            async with session.post(
                f"{self.base_url}/chat/completions",
                json=payload,
                timeout=aiohttp.ClientTimeout(total=120)
            ) as response:
                # Fail with the HTTP status (401, 429, 5xx, ...) instead of a
                # confusing KeyError on a missing "choices" field.
                response.raise_for_status()
                resp = await response.json()

            content = resp["choices"][0]["message"]["content"] or ""
            input_tokens, output_tokens = self._count_tokens(resp, prompt, content)
            self.token_counts[0] += input_tokens
            self.token_counts[1] += output_tokens
            return content
        except Exception as e:
            print(f"LLM Error: {str(e)}")
            raise

    @staticmethod
    def _count_tokens(resp: Dict[str, Any], prompt: str, content: str) -> Tuple[int, int]:
        """Return (input_tokens, output_tokens) for one completion.

        Uses the response's ``usage`` block when it reports both counts as
        integers. Otherwise falls back to a rough 4-characters-per-token estimate.
        """
        usage = resp.get("usage") or {}
        prompt_tokens = usage.get("prompt_tokens")
        completion_tokens = usage.get("completion_tokens")
        if isinstance(prompt_tokens, int) and isinstance(completion_tokens, int):
            return prompt_tokens, completion_tokens
        return len(prompt) // 4, len(content) // 4

    async def generate_answer(self, body: str, question: str) -> List[str]:
        """Sample ``num_paths`` reasoning paths for one problem concurrently.

        All requests share one HTTP session. Paths whose request failed are
        dropped so the remaining paths can still vote. Only when every path
        fails is the first error re-raised.

        Returns:
            The successful completions, each stripped of surrounding whitespace.
        """
        prompt = f"""
Context: {body}
Question: {question}
Let's think step by step, provide the final answer in the format "Final Answer: [your answer]".
"""
        async with aiohttp.ClientSession() as session:
            outcomes = await asyncio.gather(
                *(self.generate(prompt, session) for _ in range(self.num_paths)),
                return_exceptions=True,
            )

        responses = [outcome.strip() for outcome in outcomes if isinstance(outcome, str)]
        if not responses:
            failures = [outcome for outcome in outcomes if isinstance(outcome, BaseException)]
            if failures:
                raise failures[0]
        return responses

    @staticmethod
    def _normalize_number(text: str) -> Optional[str]:
        """Return a canonical string for a numeric literal, or None if it is not one.

        Commas are ignored. Integral values lose their fractional part
        ("5.0" -> "5"), so integers and integral floats compare equal as strings.
        """
        try:
            value = float(str(text).replace(",", ""))
        except ValueError:
            return None
        # Reject NaN (never equal to itself) and +/-infinity.
        if value != value or abs(value) == float("inf"):
            return None
        if value.is_integer():
            return str(int(value))
        return str(value)

    def _extract_answer(self, text: str) -> Optional[str]:
        """Extract a normalized numeric answer from a model response.

        Prefers the last LaTeX boxed answer. Otherwise uses the last
        "Final Answer:" line. From that text, the first number (sign, decimal
        point and thousands separators kept) is returned in normalized form.
        Returns None when no answer marker or no number is found.
        """
        boxed_matches = self._BOXED_RE.findall(text)
        if boxed_matches:
            raw_answer = boxed_matches[-1]
        else:
            final_lines = self._FINAL_ANSWER_RE.findall(text)
            if not final_lines:
                return None
            # The last statement wins if the model revised its answer.
            raw_answer = final_lines[-1].strip()

        number = self._NUMBER_RE.search(raw_answer)
        if number is None:
            return None
        return self._normalize_number(number.group(0))

    def _select_answer_by_voting(self, answers: List[Optional[str]]) -> Optional[str]:
        """Return the most frequent non-None answer, or None if there is none.

        Ties go to the answer that appeared first (Counter preserves insertion order).
        """
        valid_answers = [a for a in answers if a is not None]
        if not valid_answers:
            return None
        return Counter(valid_answers).most_common(1)[0][0]

    def _is_match(self, predicted: Optional[str], expected: str) -> bool:
        """Return True if *predicted* equals *expected*, compared numerically when possible."""
        if predicted is None:
            return False
        expected_number = self._normalize_number(expected)
        if expected_number is None:
            # Non-numeric reference answer: fall back to exact string equality.
            return predicted == expected
        return predicted == expected_number

    async def solve_problem(self, problem: Dict[str, Any]) -> Dict[str, Any]:
        """Solve one SVAMP problem with CoT-SC and update the running statistics.

        Args:
            problem: A dict with "Question" and optionally "Body", "Answer" and
                "ID" keys.

        Returns:
            The per-path responses and extracted answers, the voted answer, the
            reference answer (as a string), correctness, and token counts. If
            the problem has no question, returns a dict with an "error" key
            instead.
        """
        question = problem.get("Question")
        body = problem.get("Body", "")
        if not question:
            return {
                "problem_id": problem.get("ID", 0),
                "error": "No Question field found in problem",
            }

        # Token counts are reported per problem.
        self.token_counts = [0, 0]

        responses = await self.generate_answer(body, question)
        answers = [self._extract_answer(response) for response in responses]
        final_answer = self._select_answer_by_voting(answers)

        correct_answer = str(problem.get("Answer", "0"))
        # Numeric comparison: SVAMP stores answers as floats (e.g. 5.0), while
        # the model typically answers "5".
        is_correct = self._is_match(final_answer, correct_answer)

        self.update_stats(is_correct)

        return {
            "problem_id": problem.get("ID", 0),
            "body": body,
            "question": question,
            "responses": responses,
            "answers": answers,
            "answer": final_answer,
            "correct_answer": correct_answer,
            "is_correct": is_correct,
            "tokens": self.token_counts.copy()
        }

    def update_stats(self, is_correct: bool) -> None:
        """Record one graded problem and recompute accuracy as a percentage."""
        self.stats["total_problems"] += 1
        if is_correct:
            self.stats["correct_answers"] += 1
        else:
            self.stats["incorrect_answers"] += 1
        # total_problems is at least 1 after the increment above.
        self.stats["accuracy"] = (
            self.stats["correct_answers"] / self.stats["total_problems"] * 100
        )

def load_problems(
    path: str, start: int = 0, end: Optional[int] = None
) -> List[Dict[str, Any]]:
    """Load problems from *path* and return the slice ``[start:end]``.

    Accepts either one JSON document (a list of problem objects, or a single
    object) or JSON Lines (one object per non-blank line).

    Raises:
        OSError: if the file cannot be read.
        ValueError: if the content is neither a JSON list/object nor valid JSON
            Lines (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()

    try:
        data = json.loads(content)
    except json.JSONDecodeError:
        # Not a single JSON document (e.g. several objects): parse as JSON Lines.
        data = [json.loads(line) for line in content.split("\n") if line.strip()]

    if isinstance(data, dict):
        data = [data]
    if not isinstance(data, list):
        raise ValueError(f"Unsupported dataset format in {path!r}: expected a list of problems")
    return data[start:end]


async def main():
    """Parse CLI arguments, solve the selected problems and save a JSON report.

    Command-line options:
        --start / --end: half-open slice ``[start:end]`` of the dataset to solve.
        --dataset: path to a JSON array or JSON Lines file of problems.

    Results and aggregate statistics are written to
    ``log/SVAMP_cot_sc/results_<start>_<end>_acc<accuracy>%_<timestamp>.json``.
    """
    parser = argparse.ArgumentParser(description="SVAMP Solver with CoT-SC")
    parser.add_argument("--start", type=int, default=0, help="Start index in dataset")
    parser.add_argument("--end", type=int, default=1, help="End index in dataset")
    parser.add_argument("--dataset", type=str, default="SVAMP.json", help="Path to dataset")
    args = parser.parse_args()

    os.makedirs("log/SVAMP_cot_sc", exist_ok=True)

    solver = SVAmpSolver()

    try:
        problems = load_problems(args.dataset, args.start, args.end)
    except (OSError, ValueError) as e:
        print(f"Error loading dataset: {str(e)}")
        return

    results = []

    for idx, problem in enumerate(problems, args.start):
        preview = str(problem.get("Question", ""))[:50]
        print(f"\n{'='*50}\nProcessing problem {idx}: {preview}...\n{'='*50}")

        try:
            result = await solver.solve_problem(problem)
        except Exception as e:
            # If solving raises (e.g. an LLM request error), record the failure
            # and move on, so the results gathered so far are still saved. Such
            # problems never reach solver.update_stats, so accuracy covers only
            # graded problems.
            result = {"problem_id": problem.get("ID", 0), "error": f"{type(e).__name__}: {e}"}
        results.append(result)

        if "error" in result:
            print(f"\nSkipped problem {idx}: {result['error']}")
            continue

        print("\nExecution Summary:")
        print(f"Generated answers: {result['answers']}")
        print(f"Selected answer: {result['answer']}")
        print(f"Correct answer: {result['correct_answer']}")
        print(f"Verification: {'CORRECT' if result['is_correct'] else 'INCORRECT'}")
        print(f"Tokens used: {result['tokens']}")

    if results:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # The timestamp keeps repeated runs over the same range from
        # overwriting each other's log file.
        filename = (
            f"log/SVAMP_cot_sc/results_{args.start}_{args.end}"
            f"_acc{solver.stats['accuracy']:.2f}%_{timestamp}.json"
        )

        output = {
            "results": results,
            "statistics": solver.stats
        }

        with open(filename, "w", encoding="utf-8") as f:
            json.dump(output, f, indent=2, ensure_ascii=False)

        print(f"\n{'='*50}\nFinal Statistics\n{'='*50}")
        print(f"Results saved to {filename}")
        print(f"Total problems processed: {solver.stats['total_problems']}")
        print(f"Correct answers: {solver.stats['correct_answers']}")
        print(f"Incorrect answers: {solver.stats['incorrect_answers']}")
        print(f"Overall accuracy: {solver.stats['accuracy']:.2f}%")
        print(f"{'='*50}\n")

if __name__ == "__main__":
    # Script entry point: drive the async main() to completion on a new event loop.
    entry_coroutine = main()
    asyncio.run(entry_coroutine)